# src/evaluation/visualize.py
import os
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.manifold import TSNE

def plot_confusion_matrix(cm, class_names, output_dir=".", file_name="confusion_matrix.png"):
    """
    Plots and saves a confusion matrix.

    The figure is always closed before returning, including when plotting
    raises, so repeated calls do not accumulate open figures. A failure while
    writing the image is reported with ``print`` and does not raise.

    Args:
        cm (numpy.ndarray): The calculated confusion matrix. Other array-likes
            (e.g. nested lists) are converted to an array first.
        class_names (list): List of class names to be displayed on the axes.
            May be None, in which case the labels are left to
            ConfusionMatrixDisplay and the figure is sized from ``cm``.
        output_dir (str): Directory where the image will be saved. Created if
            missing; an empty string means the current working directory.
        file_name (str): Filename for the saved image.
    """
    # os.makedirs("") raises, while os.path.join("", name) is a valid
    # cwd-relative path, so only create the directory when one is named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    cm = np.asarray(cm)

    # Dynamically adjust figure size based on the number of classes
    num_classes = len(class_names) if class_names is not None else cm.shape[0]
    fig_width = max(6, num_classes * 0.8)
    fig_height = max(5, num_classes * 0.7)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
        disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation='vertical', colorbar=True)
        ax.set_title("Confusion Matrix", fontsize=14)
        ax.set_xlabel("Predicted label", fontsize=12)
        ax.set_ylabel("True label", fontsize=12)
        fig.tight_layout()  # Adjust layout to prevent labels from being cut off

        output_path = os.path.join(output_dir, file_name)
        try:
            # Save through the figure handle rather than pyplot's implicit
            # "current figure" so the right figure is always written.
            fig.savefig(output_path)
        except Exception as e:
            # Saving is best-effort: report the problem and carry on.
            print(f"Error saving confusion matrix to {output_path}: {e}")
    finally:
        plt.close(fig)  # Prevent memory leaks, even if plotting failed

def _recorded_points(history, key):
    """Return ``(epochs, values)`` for the non-None entries of ``history[key]``.

    Epoch numbers are 1-based positions in the original list, so an entry keeps
    its true epoch even when earlier entries are None. A missing key or a key
    whose value is None yields two empty lists.
    """
    series = history.get(key)
    if series is None:
        return [], []
    points = [(epoch, value) for epoch, value in enumerate(series, start=1)
              if value is not None]
    return [epoch for epoch, _ in points], [value for _, value in points]


def plot_training_history(history, num_epochs, output_dir=".", file_name="training_history.png"):
    """
    Plots and saves the training process (loss, accuracy, F1-score).

    Training loss is drawn in one panel and the validation metrics share a
    second panel; a panel is omitted when it has no data. None entries are
    skipped, and every remaining value is drawn at its own 1-based position in
    the list, so series with different numbers of recorded values can be
    plotted together. When there is nothing to plot, a message is printed and
    no figure or file is created.

    Args:
        history (dict): A dictionary containing training metrics.
                        e.g., {'train_loss': [...], 'val_accuracy': [...], 'val_f1': [...]}
                        Missing keys and keys set to None are treated as empty.
        num_epochs (int): The total number of epochs run. Currently unused;
                          kept so existing callers keep working.
        output_dir (str): Directory where the image will be saved. Created if
                          missing; an empty string means the current working directory.
        file_name (str): Filename for the saved image.
    """
    # os.makedirs("") raises, while os.path.join("", name) is a valid
    # cwd-relative path, so only create the directory when one is named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    loss_epochs, loss_values = _recorded_points(history, "train_loss")
    acc_epochs, acc_values = _recorded_points(history, "val_accuracy")
    f1_epochs, f1_values = _recorded_points(history, "val_f1")

    has_loss = bool(loss_values)
    has_validation = bool(acc_values or f1_values)

    if not (has_loss or has_validation):
        # Nothing would be saved, so do not build a figure at all.
        print("No data available to plot training history.")
        return

    fig, axes = plt.subplots(1, has_loss + has_validation, figsize=(12, 5), squeeze=False)
    panels = iter(axes[0])
    try:
        # Plot Training Loss
        if has_loss:
            ax_loss = next(panels)
            ax_loss.plot(loss_epochs, loss_values, label='Training Loss', marker='.', color='tab:blue')
            ax_loss.set_title('Training Loss Over Epochs', fontsize=14)
            ax_loss.set_xlabel('Epoch', fontsize=12)
            ax_loss.set_ylabel('Loss', fontsize=12)
            ax_loss.grid(True)
            ax_loss.legend()

        # Plot Validation Metrics
        if has_validation:
            ax_val = next(panels)
            # Each series uses its own x values: the two metrics may have a
            # different number of recorded entries.
            if acc_values:
                ax_val.plot(acc_epochs, acc_values, label='Validation Accuracy',
                            marker='o', linestyle='--', color='tab:green')
            if f1_values:
                ax_val.plot(f1_epochs, f1_values, label='Validation F1-Score',
                            marker='x', linestyle=':', color='tab:red')
            ax_val.set_title('Validation Metrics Over Epochs', fontsize=14)
            ax_val.set_xlabel('Epoch', fontsize=12)
            ax_val.set_ylabel('Metric Value', fontsize=12)
            ax_val.grid(True)
            ax_val.legend()

        fig.tight_layout()
        output_path = os.path.join(output_dir, file_name)
        try:
            fig.savefig(output_path)
        except Exception as e:
            # Saving is best-effort: report the problem and carry on.
            print(f"Error saving training history plot to {output_path}: {e}")
    finally:
        plt.close(fig)  # Always release the figure, even if plotting failed

def plot_tsne_embeddings(features, labels, class_names, 
                         output_dir=".", file_name="tsne_embeddings.png", 
                         title="t-SNE Visualization of Features", perplexity=30.0):
    """
    Visualizes and saves high-dimensional features in 2D using t-SNE.

    The call is skipped (a message is printed and nothing is saved) when there
    are fewer than two samples, when no positive perplexity can be used, or
    when t-SNE itself fails. A failure while writing the image is also reported
    with ``print`` rather than raised. The figure is always closed before
    returning.

    Args:
        features (np.array): Array of features to visualize (n_samples, feature_dim).
            Other array-likes (e.g. nested lists) are converted to an array.
        labels (np.array): True labels for each sample (n_samples,). Must be integers.
            Other array-likes are converted to an array.
        class_names (list): The actual class name corresponding to each label index.
            Labels outside ``0 .. len(class_names) - 1`` (or all labels, when
            this is None or empty) are shown as ``Class <label>``.
        output_dir (str): Directory where the image will be saved. Created if
            missing; an empty string means the current working directory.
        file_name (str): Filename for the saved image.
        title (str): The title of the plot.
        perplexity (float): Perplexity value for t-SNE. Must be less than n_samples;
            larger values are reduced to ``n_samples - 1``.
    """
    # os.makedirs("") raises, while os.path.join("", name) is a valid
    # cwd-relative path, so only create the directory when one is named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Accept plain lists: `.shape` and the element-wise `labels == value`
    # comparison below both need real arrays.
    features = np.asarray(features)
    labels = np.asarray(labels)

    num_samples = features.shape[0]

    if num_samples <= 1:
        print(f"t-SNE Warning: Too few samples ({num_samples}) to visualize. Skipping.")
        return

    # Perplexity must be less than n_samples. Common range is 5-50.
    requested_perplexity = float(perplexity)
    actual_perplexity = min(requested_perplexity, float(num_samples - 1))
    if actual_perplexity < requested_perplexity:
        # Only report an adjustment when the value was actually clamped.
        print(f"t-SNE Info: Perplexity was too large for the number of samples, adjusted to {actual_perplexity}.")

    if actual_perplexity <= 0:
        print(f"t-SNE Warning: Cannot get a valid perplexity ({actual_perplexity}) for n_samples={num_samples}. Skipping.")
        return

    print(f"Running t-SNE (n_samples: {num_samples}, perplexity: {actual_perplexity})...")
    # The iteration count is left at the library default (1000). Passing it
    # explicitly needs `n_iter` on older scikit-learn and `max_iter` on newer
    # releases, so omitting it keeps the call valid on both.
    tsne = TSNE(n_components=2, random_state=42, perplexity=actual_perplexity,
                learning_rate='auto', init='pca', metric='cosine')

    try:
        tsne_results = tsne.fit_transform(features)
    except Exception as e:
        print(f"t-SNE fit_transform error: {e}. Perplexity={actual_perplexity}, n_samples={num_samples}. Skipping visualization.")
        return

    unique_labels = np.unique(labels)
    num_unique_labels = len(unique_labels)

    def legend_name(label_val):
        # Bounds are checked on both sides so a negative label is never
        # resolved by Python's from-the-end indexing to an unrelated class.
        if class_names is not None and 0 <= label_val < len(class_names):
            return class_names[label_val]
        return f'Class {label_val}'

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        # Use a colormap to assign different colors to each class.
        # plt.get_cmap replaces the deprecated plt.cm.get_cmap accessor.
        cmap = plt.get_cmap("jet", num_unique_labels)

        for i, label_val in enumerate(unique_labels):
            mask = (labels == label_val)
            ax.scatter(tsne_results[mask, 0], tsne_results[mask, 1],
                       color=cmap(i),
                       label=legend_name(label_val),
                       alpha=0.7, s=50)

        ax.set_title(title, fontsize=16)
        ax.set_xlabel("t-SNE Dimension 1", fontsize=12)
        ax.set_ylabel("t-SNE Dimension 2", fontsize=12)

        if num_unique_labels > 0:
            ax.legend(title="Classes", bbox_to_anchor=(1.05, 1), loc='upper left')

        ax.grid(True, linestyle='--', alpha=0.7)
        fig.tight_layout(rect=[0, 0, 0.85, 1])  # Adjust rect to make space for legend

        output_path = os.path.join(output_dir, file_name)
        try:
            fig.savefig(output_path)
            print(f"t-SNE visualization saved to {output_path}")
        except Exception as e:
            # Saving is best-effort: report the problem and carry on.
            print(f"Error saving t-SNE visualization to {output_path}: {e}")
    finally:
        plt.close(fig)  # Always release the figure, even if plotting failed